import argparse
import random
import platform

# Smallest encrypted-integer bit width for which tests are generated
# (lower bound of the precision loops in main()).
# NOTE: the name is misspelled ("PRECISON"); it is kept as-is because main()
# references it under this exact name.
MIN_PRECISON = 1
# Value written verbatim as the `p-error:` field of the generated
# mul_eint_int_cst tests. Presumably a per-operation error probability —
# confirm against end_to_end_linalg_leveled_gen.
from end_to_end_linalg_leveled_gen import P_ERROR

# Largest encrypted-integer bit width for which tests are generated
# (inclusive upper bound of the precision loops in main()).
MAX_PRECISION = 57

# YAML fragment holding `test-error-rates` settings (a global p-error target
# and a repetition count). main() prints it after a test spec when the
# current precision is listed in PRECISIONS_WITH_ERROR_RATES.
TEST_ERROR_RATES = """\
test-error-rates:
  - global-p-error: 0.0001
    nb-repetition: 10000"""

# Bit widths whose generated tests get the TEST_ERROR_RATES fragment appended
# (checked with `p in PRECISIONS_WITH_ERROR_RATES`).
PRECISIONS_WITH_ERROR_RATES = {
    1, 2, 3, 4, 9, 16, 24, 32, 57
}

def normalize_integer_precision(prec):
    """Round an integer bit width up to the nearest standard machine width.

    The result fills the ``width`` field of the plain-integer (``iN``)
    inputs in the generated test specs.

    Args:
        prec: Requested bit width, e.g. ``p + 1`` for a ``p``-bit test.

    Returns:
        The smallest of 8, 16, 32 and 64 that is greater than or equal to
        ``prec``.

    Raises:
        ValueError: If ``prec`` is larger than 64 bits. ``ValueError``
            subclasses ``Exception``, so callers that caught the previous
            bare ``Exception`` still catch it.
    """
    for width in (8, 16, 32, 64):
        if prec <= width:
            return width
    # Previously a message-less ``Exception()``; name the offending value so
    # a failing generator run is easy to diagnose.
    raise ValueError(
        "integer precision {0} exceeds the maximum supported width of 64 bits".format(prec))

def main(args):
    print("# /!\ DO NOT EDIT MANUALLY THIS FILE MANUALLY")
    print("# /!\ THIS FILE HAS BEEN GENERATED THANKS THE end_to_end_levelled_gen.py scripts")
    print("# This reference file aims to test all levelled ops with all bitwidth than we known that the compiler/optimizer support.\n\n")
    # unsigned
    step = 1
    if args.minimal:
        step = 6
    for p in range(MIN_PRECISON, MAX_PRECISION+1, step):
        if p != 1:
            print("---")
        def may_check_error_rate():
            if p in PRECISIONS_WITH_ERROR_RATES:
                print(TEST_ERROR_RATES)
        max_value = (2 ** p) - 1
        integer_bitwidth = p + 1
        max_constant = min((2 ** (57-p)) - 1, max_value)

        # identity
        print("description: identity_{0}bits".format(p))
        print("program: |")
        print(
            "  func.func @main(%arg0: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p))
        print("    return %arg0: !FHE.eint<{0}>".format(p))
        print("  }")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(max_value))
        print("    outputs:")
        print("    - scalar: {0}".format(max_value))
        may_check_error_rate()
        print("---")
        # zero_tensor
        print("description: zero_tensor_{0}bits".format(p))
        print("program: |")
        print(
            "  func.func @main() -> tensor<2x2x4x!FHE.eint<{0}>> {{".format(p))
        print(
            "    %0 = \"FHE.zero_tensor\"() : () -> tensor<2x2x4x!FHE.eint<{0}>>".format(p))
        print("    return %0: tensor<2x2x4x!FHE.eint<{0}>>".format(p))
        print("  }")
        print("tests:")
        print("  - outputs:")
        print("    - tensor: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]")
        print("      shape: [2,2,4]")
        may_check_error_rate()
        print("---")
        # add_eint_int_cst
        print("description: add_eint_int_cst_{0}bits".format(p))
        print("program: |")
        print(
            "  func.func @main(%arg0: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p))
        print("    %0 = arith.constant 1 : i{0}".format(integer_bitwidth))
        print(
            "    %1 = \"FHE.add_eint_int\"(%arg0, %0): (!FHE.eint<{0}>, i{1}) -> (!FHE.eint<{0}>)".format(p, integer_bitwidth))
        print("    return %1: !FHE.eint<{0}>".format(p))
        print("  }")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(max_value-1))
        print("    outputs:")
        print("    - scalar: {0}".format(max_value))
        may_check_error_rate()
        print("---")
        # add_eint_int_arg
        if p <= 28:
            # above 28 bits the *arg test doesn't have solution
            # TODO: Make a test that test that
            print("description: add_eint_int_arg_{0}bits".format(p))
            print("program: |")
            print(
                "  func.func @main(%arg0: !FHE.eint<{0}>, %arg1: i{1}) -> !FHE.eint<{0}> {{".format(p, integer_bitwidth))
            print(
                "    %0 = \"FHE.add_eint_int\"(%arg0, %arg1): (!FHE.eint<{0}>, i{1}) -> (!FHE.eint<{0}>)".format(p, integer_bitwidth))
            print("    return %0: !FHE.eint<{0}>".format(p))
            print("  }")
            print("tests:")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value-1))
            print("    - scalar: {0}".format(1))
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("    - scalar: {0}".format(0))
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            print("  - inputs:")
            print("    - scalar: {0}".format((max_value-1) >> 1))
            print("    - scalar: {0}".format((max_value >> 1) + 1))
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            may_check_error_rate()
            print("---")
        # add_eint
        print("description: add_eint_{0}_bits".format(p))
        print("program: |")
        print(
            "  func.func @main(%arg0: !FHE.eint<{0}>, %arg1: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p))
        print(
            "    %res = \"FHE.add_eint\"(%arg0, %arg1): (!FHE.eint<{0}>, !FHE.eint<{0}>) -> !FHE.eint<{0}>".format(p))
        print("    return %res: !FHE.eint<{0}>".format(p))
        print("  }")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(((2 ** p) >> 1) - 1))
        print("    - scalar: {0}".format(((2 ** p) >> 1)))
        print("    outputs:")
        print("    - scalar: {0}".format((2 ** p) - 1))
        may_check_error_rate()
        print("---")
        # sub_eint_int_cst
        print("description: sub_eint_int_cst_{0}bits".format(p))
        print("program: |")
        print(
            "  func.func @main(%arg0: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p))
        print("    %0 = arith.constant {1} : i{0}".format(
            integer_bitwidth, max_constant))
        print(
            "    %1 = \"FHE.sub_eint_int\"(%arg0, %0): (!FHE.eint<{0}>, i{1}) -> (!FHE.eint<{0}>)".format(p, integer_bitwidth))
        print("    return %1: !FHE.eint<{0}>".format(p))
        print("  }")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(max_value))
        print("    outputs:")
        print("    - scalar: {0}".format(max_value-max_constant))
        may_check_error_rate()
        print("---")
        # sub_eint_int_arg
        if p <= 28:
            # above 28 bits the *arg test doesn't have solution
            # TODO: Make a test that test that
            print("description: sub_eint_int_arg_{0}bits".format(p))
            print("program: |")
            print(
                "  func.func @main(%arg0: !FHE.eint<{0}>, %arg1: i{1}) -> !FHE.eint<{0}> {{".format(p, integer_bitwidth))
            print(
                "    %1 = \"FHE.sub_eint_int\"(%arg0, %arg1): (!FHE.eint<{0}>, i{1}) -> (!FHE.eint<{0}>)".format(p, integer_bitwidth))
            print("    return %1: !FHE.eint<{0}>".format(p))
            print("  }")
            print("tests:")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("    - scalar: {0}".format(max_value))
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: 0")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("    - scalar: 0")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value - 1))
            print("    - scalar: {0}".format(max_value >> 1))
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value >> 1))
            may_check_error_rate()
            print("---")
        # sub_int_eint_cst
        print("description: sub_int_eint_cst_{0}bits".format(p))
        print("program: |")
        print(
            "  func.func @main(%arg0: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p))
        print("    %0 = arith.constant {1} : i{0}".format(
            integer_bitwidth, max_constant))
        print(
            "    %1 = \"FHE.sub_int_eint\"(%0, %arg0): (i{1}, !FHE.eint<{0}>) -> (!FHE.eint<{0}>)".format(p, integer_bitwidth))
        print("    return %1: !FHE.eint<{0}>".format(p))
        print("  }")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(max_constant))
        print("    outputs:")
        print("    - scalar: 0")
        may_check_error_rate()
        print("---")
        # sub_int_eint_arg
        if p <= 28:
            # above 28 bits the *arg test doesn't have solution
            # TODO: Make a test that test that
            print("description: sub_int_eint_arg_{0}bits".format(p))
            print("program: |")
            print(
                "  func.func @main(%arg0: i{1}, %arg1: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p, integer_bitwidth))
            print(
                "    %1 = \"FHE.sub_int_eint\"(%arg0, %arg1): (i{1}, !FHE.eint<{0}>) -> (!FHE.eint<{0}>)".format(p, integer_bitwidth))
            print("    return %1: !FHE.eint<{0}>".format(p))
            print("  }")
            print("tests:")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    - scalar: {0}".format(max_value))
            print("    outputs:")
            print("    - scalar: 0")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    - scalar: 0")
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value - 1))
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    - scalar: {0}".format(max_value >> 1))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value >> 1))
            may_check_error_rate()
            print("---")
        # sub_eint
        print("description: sub_eint_{0}bits".format(p))
        print("program: |")
        print(
            "  func.func @main(%arg0: !FHE.eint<{0}>, %arg1: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p))
        print(
            "    %1 = \"FHE.sub_eint\"(%arg0, %arg1): (!FHE.eint<{0}>, !FHE.eint<{0}>) -> (!FHE.eint<{0}>)".format(p))
        print("    return %1: !FHE.eint<{0}>".format(p))
        print("  }")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(max_value))
        print("    - scalar: {0}".format(max_value))
        print("    outputs:")
        print("    - scalar: 0")
        print("  - inputs:")
        print("    - scalar: {0}".format(max_value))
        print("    - scalar: 0")
        print("    outputs:")
        print("    - scalar: {0}".format(max_value))
        print("  - inputs:")
        print("    - scalar: {0}".format(max_value - 1))
        print("    - scalar: {0}".format(max_value >> 1))
        print("    outputs:")
        print("    - scalar: {0}".format(max_value >> 1))
        may_check_error_rate()
        print("---")
        # mul_eint_int_cst
        print("description: mul_eint_int_cst_{0}bits".format(p))
        print("program: |")
        print(
            "  func.func @main(%arg0: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p))
        print("    %0 = arith.constant 2 : i{0}".format(integer_bitwidth))
        print(
            "    %1 = \"FHE.mul_eint_int\"(%arg0, %0): (!FHE.eint<{0}>, i{1}) -> (!FHE.eint<{0}>)".format(p, integer_bitwidth))
        print("    return %1: !FHE.eint<{0}>".format(p))
        print("  }")
        if p <= 57:
            print(f"p-error: {P_ERROR}")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: 0")
        print("    outputs:")
        print("    - scalar: 0")
        print("  - inputs:")
        print("    - scalar: {0}".format(max_value >> 1))
        print("    outputs:")
        print("    - scalar: {0}".format(max_value - 1))
        may_check_error_rate()
        print("---")
        # mul_eint_int_arg
        if p <= 28:
            # above 28 bits the *arg test doesn't have solution
            # TODO: Make a test that test that
            print("description: mul_eint_int_arg_{0}bits".format(p))
            print("program: |")
            print(
                "  func.func @main(%arg0: !FHE.eint<{0}>, %arg1: i{1}) -> !FHE.eint<{0}> {{".format(p, integer_bitwidth))
            print(
                "    %1 = \"FHE.mul_eint_int\"(%arg0, %arg1): (!FHE.eint<{0}>, i{1}) -> (!FHE.eint<{0}>)".format(p, integer_bitwidth))
            print("    return %1: !FHE.eint<{0}>".format(p))
            print("  }")
            print("tests:")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("    - scalar: 1")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            if not args.minimal:
                print("  - inputs:")
                print("    - scalar: 0")
                print("    - scalar: {0}".format(max_value))
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: 0")
                print("  - inputs:")
                print("    - scalar: {0}".format(max_value))
                print("    - scalar: 0")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: 0")
                print("  - inputs:")
                print("    - scalar: 1")
                print("    - scalar: {0}".format(max_value))
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(max_value))
            may_check_error_rate()
            print("---")
        # mul_eint
        # Disable precision on MacOS, since the instance used for by the CI
        # does not provide enough RAM for a precision higher than 8 bits
        if (platform.system() == "Darwin" and p <= 8) or \
           (platform.system() != "Darwin" and p <= 15):
            def gen_random_encodable():
                while True:
                    a = random.randint(1, max_value)
                    b = random.randint(1, max_value)
                    if a*b <= max_value:
                        return a, b

            print("description: mul_eint_{0}bits".format(p+1))
            print("program: |")
            print(
                "  func.func @main(%arg0: !FHE.eint<{0}>, %arg1: !FHE.eint<{0}>) -> !FHE.eint<{0}> {{".format(p+1))
            print(
                "    %1 = \"FHE.mul_eint\"(%arg0, %arg1): (!FHE.eint<{0}>, !FHE.eint<{0}>) -> (!FHE.eint<{0}>)".format(p+1))
            print("    return %1: !FHE.eint<{0}>".format(p+1))
            print("  }")
            print("tests:")
            inp = gen_random_encodable()
            print("  - inputs:")
            print("    - scalar: {0}".format(inp[0]))
            print("    - scalar: {0}".format(inp[1]))
            print("    outputs:")
            print("    - scalar: {0}".format(inp[0]*inp[1]))
            if not args.minimal:
                print("  - inputs:")
                print("    - scalar: 0")
                print("    - scalar: 0")
                print("    outputs:")
                print("    - scalar: 0")
                print("  - inputs:")
                print("    - scalar: {0}".format(max_value))
                print("    - scalar: 0")
                print("    outputs:")
                print("    - scalar: 0")
                print("  - inputs:")
                print("    - scalar: 0")
                print("    - scalar: {0}".format(max_value))
                print("    outputs:")
                print("    - scalar: 0")
                print("  - inputs:")
                print("    - scalar: 1")
                print("    - scalar: {0}".format(max_value))
                print("    outputs:")
                print("    - scalar: {0}".format(max_value))
                print("  - inputs:")
                print("    - scalar: {0}".format(max_value))
                print("    - scalar: 1")
                print("    outputs:")
                print("    - scalar: {0}".format(max_value))
            print("---")
    # signed
    for p in range(MIN_PRECISON, MAX_PRECISION+1):
        print("---")
        def may_check_error_rate():
            if p in PRECISIONS_WITH_ERROR_RATES:
                print(TEST_ERROR_RATES)
        min_value = -(2 ** (p - 1))
        max_value = abs(min_value) - 1
        integer_bitwidth = p + 1
        max_constant = min((2 ** (57-p)) - 1, max_value)

        # identity
        print("description: signed_identity_{0}bits".format(p))
        print("program: |")
        print("  func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
        print("    return %arg0: !FHE.esint<{0}>".format(p))
        print("  }")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(min_value))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(min_value))
        print("      signed: true")
        print("  - inputs:")
        print("    - scalar: {0}".format(max_value))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(max_value))
        print("      signed: true")
        may_check_error_rate()

        print("---")

        # zero_tensor
        print("description: signed_zero_tensor_{0}bits".format(p))
        print("program: |")
        print("  func.func @main() -> tensor<2x2x4x!FHE.esint<{0}>> {{".format(p))
        print("    %0 = \"FHE.zero_tensor\"() : () -> tensor<2x2x4x!FHE.esint<{0}>>".format(p))
        print("    return %0: tensor<2x2x4x!FHE.esint<{0}>>".format(p))
        print("  }")
        print("tests:")
        print("  - outputs:")
        print("    - tensor: [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]")
        print("      shape: [2,2,4]")
        print("      signed: true")
        may_check_error_rate()

        print("---")

        # add_eint_int_cst
        print("description: signed_add_eint_int_cst_{0}bits".format(p))
        print("program: |")
        print("  func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
        print("    %0 = arith.constant 1 : i{0}".format(integer_bitwidth))
        print("    %1 = \"FHE.add_eint_int\"(%arg0, %0): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
        print("    return %1: !FHE.esint<{0}>".format(p))
        print("  }")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(-1))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(0))
        print("      signed: true")
        print("  - inputs:")
        print("    - scalar: {0}".format(max_value-1))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(max_value))
        print("      signed: true")
        print("  - inputs:")
        print("    - scalar: {0}".format(min_value))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(min_value + 1))
        print("      signed: true")
        may_check_error_rate()

        print("---")

        # add_eint_int_arg
        if p <= 28:
            # above 28 bits the *arg test doesn't have solution
            # TODO: Make a test that test that
            print("description: signed_add_eint_int_arg_{0}bits".format(p))
            print("program: |")
            print("  func.func @main(%arg0: !FHE.esint<{0}>, %arg1: i{1}) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth))
            print("    %0 = \"FHE.add_eint_int\"(%arg0, %arg1): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
            print("    return %0: !FHE.esint<{0}>".format(p))
            print("  }")
            print("tests:")
            print("  - inputs:")
            print("    - scalar: {0}".format(min_value))
            print("      signed: true")
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(min_value))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(min_value))
            print("      signed: true")
            print("    - scalar: {0}".format(1))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(min_value + 1))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value-1))
            print("      signed: true")
            print("    - scalar: {0}".format(1))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(-1))
            print("      signed: true")
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(-1))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(-1))
            print("      signed: true")
            print("    - scalar: {0}".format(1))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            may_check_error_rate()

        print("---")

        # add_eint
        print("description: signed_add_eint_{0}bits".format(p))
        print("program: |")
        print("  func.func @main(%arg0: !FHE.esint<{0}>, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
        print("    %res = \"FHE.add_eint\"(%arg0, %arg1): (!FHE.esint<{0}>, !FHE.esint<{0}>) -> !FHE.esint<{0}>".format(p))
        print("    return %res: !FHE.esint<{0}>".format(p))
        print("  }")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(((2 ** (p - 1)) >> 1) - 1))
        print("      signed: true")
        print("    - scalar: {0}".format(((2 ** (p - 1)) >> 1)))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(-1 if p == 1 else (2 ** (p - 1)) - 1))
        print("      signed: true")
        print("  - inputs:")
        print("    - scalar: {0}".format(-(2 ** (p - 1))))
        print("      signed: true")
        print("    - scalar: {0}".format(((2 ** (p - 1)) - 1)))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(-1))
        print("      signed: true")
        may_check_error_rate()

        print("---")

        # sub_eint_int_cst
        print("description: signed_sub_eint_int_cst_{0}bits".format(p))
        print("program: |")
        print("  func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
        print("    %0 = arith.constant 1 : i{0}".format(integer_bitwidth))
        print("    %1 = \"FHE.sub_eint_int\"(%arg0, %0): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
        print("    return %1: !FHE.esint<{0}>".format(p))
        print("  }")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(max_value))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(max_value - 1))
        print("      signed: true")
        print("  - inputs:")
        print("    - scalar: {0}".format(0))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(-1))
        print("      signed: true")
        print("  - inputs:")
        print("    - scalar: {0}".format(min_value + 1))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(min_value))
        print("      signed: true")
        may_check_error_rate()

        print("---")

        # sub_eint_int_arg
        if p <= 28:
            # above 28 bits the *arg test doesn't have solution
            # TODO: Make a test that test that
            print("description: signed_sub_eint_int_arg_{0}bits".format(p))
            print("program: |")
            print("  func.func @main(%arg0: !FHE.esint<{0}>, %arg1: i{1}) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth))
            print("    %1 = \"FHE.sub_eint_int\"(%arg0, %arg1): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
            print("    return %1: !FHE.esint<{0}>".format(p))
            print("  }")
            print("tests:")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("    - scalar: {0}".format(1))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value - 1))
            print("      signed: true")
            if p != 28:
                print("  - inputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("    - scalar: {0}".format(2 * max_value))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(-max_value))
                print("      signed: true")
            may_check_error_rate()

        print("---")

        # sub_int_eint_cst
        if p != 1:
            print("description: signed_sub_int_eint_cst_{0}bits".format(p))
            print("program: |")
            print("  func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
            print("    %0 = arith.constant 1 : i{0}".format(integer_bitwidth))
            print("    %1 = \"FHE.sub_int_eint\"(%0, %arg0): (i{1}, !FHE.esint<{0}>) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
            print("    return %1: !FHE.esint<{0}>".format(p))
            print("  }")
            print("tests:")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("    outputs:")
            print("    - scalar: {0}".format(min_value + 2))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            print("    outputs:")
            print("    - scalar: {0}".format(1))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(min_value + 2))
            print("      signed: true")
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            may_check_error_rate()

        print("---")

        # sub_int_eint_arg
        if p <= 28:
            # above 28 bits the *arg test doesn't have solution
            # TODO: Make a test that test that
            print("description: signed_sub_int_eint_arg_{0}bits".format(p))
            print("program: |")
            print("  func.func @main(%arg0: i{1}, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth))
            print("    %1 = \"FHE.sub_int_eint\"(%arg0, %arg1): (i{1}, !FHE.esint<{0}>) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
            print("    return %1: !FHE.esint<{0}>".format(p))
            print("  }")
            print("tests:")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("    outputs:")
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    - scalar: {0}".format(0))
            print("      signed: true")
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    - scalar: {0}".format(1))
            print("      signed: true")
            print("    outputs:")
            print("    - scalar: {0}".format(max_value - 1))
            print("      signed: true")
            if p != 28:
                print("  - inputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    - scalar: {0}".format(2 * max_value))
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: {0}".format(-max_value))
                print("      signed: true")
            may_check_error_rate()

        print("---")

        # sub_eint
        print("description: signed_sub_eint_{0}bits".format(p))
        print("program: |")
        print("  func.func @main(%arg0: !FHE.esint<{0}>, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
        print("    %res = \"FHE.sub_eint\"(%arg0, %arg1): (!FHE.esint<{0}>, !FHE.esint<{0}>) -> !FHE.esint<{0}>".format(p))
        print("    return %res: !FHE.esint<{0}>".format(p))
        print("  }")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(0))
        print("      signed: true")
        print("    - scalar: {0}".format(1))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(-1))
        print("      signed: true")
        print("  - inputs:")
        print("    - scalar: {0}".format(max_value))
        print("      signed: true")
        print("    - scalar: {0}".format(0))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(max_value))
        print("      signed: true")
        print("  - inputs:")
        print("    - scalar: {0}".format(min_value))
        print("      signed: true")
        print("    - scalar: {0}".format(0))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(min_value))
        print("      signed: true")
        print("  - inputs:")
        print("    - scalar: {0}".format(0))
        print("      signed: true")
        print("    - scalar: {0}".format(min_value + 1))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(max_value))
        print("      signed: true")
        print("  - inputs:")
        print("    - scalar: {0}".format(0))
        print("      signed: true")
        print("    - scalar: {0}".format(max_value))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(min_value + 1))
        print("      signed: true")
        may_check_error_rate()

        print("---")

        # mul_eint_int_cst
        print("description: signed_mul_eint_int_cst_{0}bits".format(p))
        print("program: |")
        print("  func.func @main(%arg0: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p))
        print("    %0 = arith.constant 2 : i{0}".format(integer_bitwidth))
        print("    %1 = \"FHE.mul_eint_int\"(%arg0, %0): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
        print("    return %1: !FHE.esint<{0}>".format(p))
        print("  }")
        if p <= 57:
            print(f"p-error: {P_ERROR}")
        print("tests:")
        print("  - inputs:")
        print("    - scalar: {0}".format(0))
        print("      signed: true")
        print("    outputs:")
        print("    - scalar: {0}".format(0))
        print("      signed: true")
        if p != 1:
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value // 2))
            print("      signed: true")
            print("    outputs:")
            print("    - scalar: {0}".format(max_value - 1))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(min_value // 2))
            print("      signed: true")
            print("    outputs:")
            print("    - scalar: {0}".format(min_value))
            print("      signed: true")
        may_check_error_rate()

        print("---")

        # mul_eint_int_arg
        if p <= 28:
            # above 28 bits the *arg test doesn't have solution
            # TODO: Make a test that test that
            print("description: signed_mul_eint_int_arg_{0}bits".format(p))
            print("program: |")
            print("  func.func @main(%arg0: !FHE.esint<{0}>, %arg1: i{1}) -> !FHE.esint<{0}> {{".format(p, integer_bitwidth))
            print("    %0 = \"FHE.mul_eint_int\"(%arg0, %arg1): (!FHE.esint<{0}>, i{1}) -> (!FHE.esint<{0}>)".format(p, integer_bitwidth))
            print("    return %0: !FHE.esint<{0}>".format(p))
            print("  }")
            print("tests:")
            print("  - inputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("    - scalar: {0}".format(1))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(max_value))
            print("      signed: true")
            print("  - inputs:")
            print("    - scalar: {0}".format(1))
            print("      signed: true")
            print("    - scalar: {0}".format(min_value))
            print("      signed: true")
            print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
            print("    outputs:")
            print("    - scalar: {0}".format(min_value))
            print("      signed: true")
            if not args.minimal:
                print("  - inputs:")
                print("    - scalar: {0}".format(0))
                print("      signed: true")
                print("    - scalar: {0}".format(0))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(0))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(min_value))
                print("      signed: true")
                print("    - scalar: {0}".format(0))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(0))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(0))
                print("      signed: true")
                print("    - scalar: {0}".format(min_value))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(0))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("    - scalar: {0}".format(0))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(0))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(0))
                print("      signed: true")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(0))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(min_value))
                print("      signed: true")
                print("    - scalar: {0}".format(1))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(min_value))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(1))
                print("      signed: true")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("    - scalar: {0}".format(-1))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(min_value + 1))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(-1))
                print("      signed: true")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(min_value + 1))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(min_value + 1))
                print("      signed: true")
                print("    - scalar: {0}".format(-1))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(-1))
                print("      signed: true")
                print("    - scalar: {0}".format(min_value + 1))
                print("      signed: true")
                print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                print("    outputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                if p > 2:
                    print("  - inputs:")
                    print("    - scalar: {0}".format(3))
                    print("      signed: true")
                    print("    - scalar: {0}".format(1))
                    print("      signed: true")
                    print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                    print("    outputs:")
                    print("    - scalar: {0}".format(3))
                    print("      signed: true")
                    print("  - inputs:")
                    print("    - scalar: {0}".format(3))
                    print("      signed: true")
                    print("    - scalar: {0}".format(-1))
                    print("      signed: true")
                    print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                    print("    outputs:")
                    print("    - scalar: {0}".format(-3))
                    print("      signed: true")
                    print("  - inputs:")
                    print("    - scalar: {0}".format(-3))
                    print("      signed: true")
                    print("    - scalar: {0}".format(1))
                    print("      signed: true")
                    print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                    print("    outputs:")
                    print("    - scalar: {0}".format(-3))
                    print("      signed: true")
                    print("  - inputs:")
                    print("    - scalar: {0}".format(-3))
                    print("      signed: true")
                    print("    - scalar: {0}".format(-1))
                    print("      signed: true")
                    print("      width: {0}".format(normalize_integer_precision(integer_bitwidth)))
                    print("    outputs:")
                    print("    - scalar: {0}".format(3))
                    print("      signed: true")
            may_check_error_rate()

        print("---")
        # mul_eint
        if (platform.system() == "Darwin" and 2 <= p <= 8) or \
           (platform.system() != "Darwin" and 2 <= p <= 15):
            def gen_random_encodable(p):
                while True:
                    a = random.randint(min_value, max_value)
                    b = random.randint(min_value, max_value)
                    if min_value <= a*b <= max_value:
                        if p == 3:
                            return a, b
                        if not (a in [-1, 1, 0] or b in [-1, 1, 0]):
                            return a, b

            print("description: signed_mul_eint_{0}bits".format(p+1))
            print("program: |")
            print(
                "  func.func @main(%arg0: !FHE.esint<{0}>, %arg1: !FHE.esint<{0}>) -> !FHE.esint<{0}> {{".format(p+1))
            print(
                "    %1 = \"FHE.mul_eint\"(%arg0, %arg1): (!FHE.esint<{0}>, !FHE.esint<{0}>) -> (!FHE.esint<{0}>)".format(p+1))
            print("    return %1: !FHE.esint<{0}>".format(p+1))
            print("  }")
            print("tests:")
            inp = gen_random_encodable(p+1)
            print("  - inputs:")
            print("    - scalar: {0}".format(inp[0]))
            print("      signed: true")
            print("    - scalar: {0}".format(inp[1]))
            print("      signed: true")
            print("    outputs:")
            print("    - scalar: {0}".format(inp[0]*inp[1]))
            print("      signed: true")
            if not args.minimal:
                print("  - inputs:")
                print("    - scalar: 0")
                print("      signed: true")
                print("    - scalar: 0")
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: 0")
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("    - scalar: 0")
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: 0")
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(min_value))
                print("      signed: true")
                print("    - scalar: 0")
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: 0")
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: 0")
                print("      signed: true")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: 0")
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: 0")
                print("      signed: true")
                print("    - scalar: {0}".format(min_value))
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: 0")
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: 1")
                print("      signed: true")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: 1")
                print("      signed: true")
                print("    - scalar: {0}".format(min_value))
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: {0}".format(min_value))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("    - scalar: 1")
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(min_value))
                print("      signed: true")
                print("    - scalar: 1")
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: {0}".format(min_value))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: -1")
                print("      signed: true")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: {0}".format(-max_value))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: -1")
                print("      signed: true")
                print("    - scalar: {0}".format(min_value+1))
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: {0}".format(-(min_value+1)))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(max_value))
                print("      signed: true")
                print("    - scalar: -1")
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: {0}".format(-max_value))
                print("      signed: true")
                print("  - inputs:")
                print("    - scalar: {0}".format(min_value+1))
                print("      signed: true")
                print("    - scalar: -1")
                print("      signed: true")
                print("    outputs:")
                print("    - scalar: {0}".format(-(min_value+1)))
                print("      signed: true")
            print("---")

def parse_bool_flag(value):
    """Convert a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(value)`` on the raw string,
    so every non-empty string -- including ``"False"`` and ``"0"`` -- would
    become True. This helper parses the common spellings explicitly instead.

    :param value: raw string given on the command line.
    :returns: the boolean it denotes.
    :raises argparse.ArgumentTypeError: if ``value`` is not a recognized
        boolean spelling (argparse turns this into a usage error).
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {0!r}".format(value))


def build_arg_parser():
    """Build the command-line parser for this generator script.

    :returns: an ``argparse.ArgumentParser`` exposing ``--minimal``.
    """
    cli = argparse.ArgumentParser()
    # nargs="?" + const=True accepts both the bare flag (``--minimal``) and
    # the explicit forms (``--minimal True`` / ``--minimal=false``), so
    # existing invocations that pass a value keep working.
    cli.add_argument(
        "--minimal",
        help="Specify whether to generate minimal tests only",
        type=parse_bool_flag,
        nargs="?",
        const=True,
        default=False,
    )
    return cli


if __name__ == "__main__":
    main(build_arg_parser().parse_args())
